/* ************************************************************************
 * Copyright (c) 2018 Advanced Micro Devices, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 *
 * ************************************************************************ */

#include "../../utils/def.hpp"
#include "preconditioner_blockprecond.hpp"
#include "preconditioner.hpp"
#include "../solver.hpp"

#include "../../base/local_matrix.hpp"

#include "../../base/local_vector.hpp"

#include "../../utils/allocate_free.hpp"
#include "../../utils/log.hpp"

#include <complex>

namespace rocalution {

template <class OperatorType, class VectorType, typename ValueType>
BlockPreconditioner<OperatorType, VectorType, ValueType>::BlockPreconditioner()
{
    log_debug(this, "BlockPreconditioner::BlockPreconditioner()", "default constructor");

    this->num_blocks_  = 0;
    this->block_sizes_ = NULL;

    this->op_mat_format_      = false;
    this->precond_mat_format_ = CSR;

    this->diag_solve_ = false;
    this->A_last_     = NULL;
}

template <class OperatorType, class VectorType, typename ValueType>
BlockPreconditioner<OperatorType, VectorType, ValueType>::~BlockPreconditioner()
{
    log_debug(this, "BlockPreconditioner::~BlockPreconditioner()", "destructor");

    this->Clear();
}

template <class OperatorType, class VectorType, typename ValueType>
void BlockPreconditioner<OperatorType, VectorType, ValueType>::Clear(void)
{
    log_debug(this, "BlockPreconditioner::Clear()", this->build_);

    if(this->build_ == true)
    {
        for(int i = 0; i < this->num_blocks_; ++i)
        {
            this->x_block_[i]->Clear();
            this->tmp_block_[i]->Clear();
            delete this->x_block_[i];
            delete this->tmp_block_[i];

            if(this->D_solver_[i] != NULL)
            {
                this->D_solver_[i]->Clear();
                this->D_solver_[i] = NULL;
            }

            for(int j = 0; j < this->num_blocks_; ++j)
            {
                delete this->A_block_[i][j];
            }

            delete[] this->A_block_[i];
            this->A_block_[i] = NULL;
        }

        delete[] this->x_block_;
        delete[] this->tmp_block_;
        delete[] this->A_block_;
        delete[] this->D_solver_;

        free_host(&this->block_sizes_);
        this->num_blocks_ = 0;

        this->op_mat_format_      = false;
        this->precond_mat_format_ = CSR;

        this->permutation_.Clear();
        this->x_.Clear();

        this->build_ = false;
    }
}

template <class OperatorType, class VectorType, typename ValueType>
void BlockPreconditioner<OperatorType, VectorType, ValueType>::Print(void) const
{
    if(this->build_ == true)
    {
        LOG_INFO("BlockPreconditioner with " << this->num_blocks_ << " blocks:");

        for(int i = 0; i < this->num_blocks_; ++i)
        {
            this->D_solver_[i]->Print();
        }
    }
    else
    {
        LOG_INFO("BlockPreconditioner (I)LU preconditioner");
    }
}

template <class OperatorType, class VectorType, typename ValueType>
void BlockPreconditioner<OperatorType, VectorType, ValueType>::Set(
    int n, const int* size, Solver<OperatorType, VectorType, ValueType>** D_Solver)
{
    log_debug(this, "BlockPreconditioner::Set()", n, size, D_Solver);

    assert(this->build_ == false);

    assert(n > 0);
    this->num_blocks_ = n;

    this->block_sizes_ = new int[n];
    this->D_solver_    = new Solver<OperatorType, VectorType, ValueType>*[n];

    for(int i = 0; i < this->num_blocks_; ++i)
    {
        this->block_sizes_[i] = size[i];
        this->D_solver_[i]    = D_Solver[i];
    }
}

template <class OperatorType, class VectorType, typename ValueType>
void BlockPreconditioner<OperatorType, VectorType, ValueType>::SetDiagonalSolver(void)
{
    log_debug(this, "BlockPreconditioner::SetDiagonalSolver()");

    this->diag_solve_ = true;
}

template <class OperatorType, class VectorType, typename ValueType>
void BlockPreconditioner<OperatorType, VectorType, ValueType>::SetLSolver(void)
{
    log_debug(this, "BlockPreconditioner::SetLSolver()");

    this->diag_solve_ = false;
}

template <class OperatorType, class VectorType, typename ValueType>
void BlockPreconditioner<OperatorType, VectorType, ValueType>::SetExternalLastMatrix(
    const OperatorType& mat)
{
    log_debug(this, "BlockPreconditioner::SetExternalLastMatrix()", (const void*&)mat);

    assert(this->build_ == false);
    assert(this->A_last_ == NULL);

    this->A_last_ = new OperatorType;
    this->A_last_->CloneBackend(mat);
    this->A_last_->CopyFrom(mat);
}

template <class OperatorType, class VectorType, typename ValueType>
void BlockPreconditioner<OperatorType, VectorType, ValueType>::SetPermutation(
    const LocalVector<int>& perm)
{
    log_debug(this, "BlockPreconditioner::SetPermutation()", (const void*&)perm);

    assert(perm.GetSize() > 0);

    this->permutation_.CopyFrom(perm);
}

template <class OperatorType, class VectorType, typename ValueType>
void BlockPreconditioner<OperatorType, VectorType, ValueType>::Build(void)
{
    log_debug(this, "BlockPreconditioner::Build()", this->build_, " #*# begin");

    assert(this->build_ == false);
    this->build_ = true;

    assert(this->op_ != NULL);

    this->x_block_   = new VectorType*[this->num_blocks_];
    this->tmp_block_ = new VectorType*[this->num_blocks_];

    for(int i = 0; i < this->num_blocks_; ++i)
    {
        this->x_block_[i] = new VectorType;
        this->x_block_[i]->CloneBackend(*this->op_); // clone backend
        this->x_block_[i]->Allocate("Diagonal preconditioners", this->block_sizes_[i]);

        this->tmp_block_[i] = new VectorType;
        this->tmp_block_[i]->CloneBackend(*this->op_); // clone backend
        this->tmp_block_[i]->Allocate("Diagonal preconditioners", this->block_sizes_[i]);

        //    LOG_INFO("Block=" << i << " size=" << this->block_sizes_[i]);
    }

    int* offsets = NULL;
    allocate_host(this->num_blocks_ + 1, &offsets);

    offsets[0] = 0;
    for(int k = 0; k < this->num_blocks_; ++k)
    {
        offsets[k + 1] = this->block_sizes_[k];
    }

    // sum up
    for(int i = 0; i < this->num_blocks_; ++i)
    {
        offsets[i + 1] += offsets[i];
    }

    this->A_block_ = new OperatorType**[this->num_blocks_];
    for(int k = 0; k < this->num_blocks_; ++k)
    {
        this->A_block_[k] = new OperatorType*[this->num_blocks_];
    }

    for(int k = 0; k < this->num_blocks_; ++k)
    {
        for(int j = 0; j < this->num_blocks_; ++j)
        {
            this->A_block_[k][j] = new OperatorType;
            this->A_block_[k][j]->CloneBackend(*this->op_);
        }
    }

    if(this->permutation_.GetSize() > 0)
    {
        // with permutation

        assert(this->permutation_.GetSize() == this->op_->GetM());
        assert(this->permutation_.GetSize() == this->op_->GetN());

        this->permutation_.CloneBackend(*this->op_);

        LocalMatrix<ValueType> perm_op_;
        perm_op_.CloneFrom(*this->op_);
        perm_op_.Permute(this->permutation_);

        perm_op_.ExtractSubMatrices(
            this->num_blocks_, this->num_blocks_, offsets, offsets, this->A_block_);

        this->x_.CloneBackend(*this->op_);
        this->x_.Allocate("x (not permuted)", this->op_->GetM());
    }
    else
    {
        // without permutation

        this->op_->ExtractSubMatrices(
            this->num_blocks_, this->num_blocks_, offsets, offsets, this->A_block_);
    }

    free_host(&offsets);

    if(this->A_last_ != NULL)
    {
        assert(this->A_block_[this->num_blocks_ - 1][this->num_blocks_ - 1]->GetM() ==
               this->A_last_->GetM());
        assert(this->A_block_[this->num_blocks_ - 1][this->num_blocks_ - 1]->GetN() ==
               this->A_last_->GetN());

        //    this->A_block_[this->num_blocks_-1][this->num_blocks_-1]->Info();
        this->A_block_[this->num_blocks_ - 1][this->num_blocks_ - 1]->Clear();
        delete this->A_block_[this->num_blocks_ - 1][this->num_blocks_ - 1];

        this->A_block_[this->num_blocks_ - 1][this->num_blocks_ - 1] = this->A_last_;
        //    this->A_block_[this->num_blocks_-1][this->num_blocks_-1]->Info();
        this->A_last_ = NULL;
    }

    for(int i = 0; i < this->num_blocks_; ++i)
    {
        this->D_solver_[i]->SetOperator(*this->A_block_[i][i]);
        this->D_solver_[i]->Build();
    }

    for(int i = 0; i < this->num_blocks_; ++i)
    {
        // Clean U part
        for(int j = i + 1; j < this->num_blocks_; ++j)
        {
            this->A_block_[i][j]->Clear();
        }

        // Clean L part if not needed
        if(this->diag_solve_ == true)
        {
            for(int j = 0; j < i; ++j)
            {
                this->A_block_[i][j]->Clear();
            }
        }
    }

    log_debug(this, "BlockPreconditioner::Build()", this->build_, " #*# end");
}

template <class OperatorType, class VectorType, typename ValueType>
void BlockPreconditioner<OperatorType, VectorType, ValueType>::Solve(const VectorType& rhs,
                                                                     VectorType* x)
{
    log_debug(this, ":BlockPreconditioner:Solve()", " #*# begin", (const void*&)rhs, x);

    assert(this->build_ == true);

    // Extract RHS into the solution

    if(this->permutation_.GetSize() > 0)
    {
        assert(this->permutation_.GetSize() == this->x_.GetSize());
        assert(this->op_->GetM() == this->x_.GetSize());
        assert(this->x_.GetSize() == x->GetSize());
        assert(this->x_.GetSize() == rhs.GetSize());

        // with permutation
        this->x_.CopyFromPermute(rhs, this->permutation_);

        int x_offset = 0;
        for(int i = 0; i < this->num_blocks_; ++i)
        {
            this->x_block_[i]->CopyFrom(this->x_, x_offset, 0, this->block_sizes_[i]);

            x_offset += this->block_sizes_[i];
        }
    }
    else
    {
        // without permutaion
        x->CopyFrom(rhs);

        int x_offset = 0;
        for(int i = 0; i < this->num_blocks_; ++i)
        {
            this->x_block_[i]->CopyFrom(*x, x_offset, 0, this->block_sizes_[i]);

            x_offset += this->block_sizes_[i];
        }
    }

    // Solve L
    for(int i = 0; i < this->num_blocks_; ++i)
    {
        if(this->diag_solve_ == false)
        {
            for(int j = 0; j < i; ++j)
            {
                this->A_block_[i][j]->ApplyAdd(
                    *this->x_block_[j], static_cast<ValueType>(-1), this->x_block_[i]);
            }
        }

        this->D_solver_[i]->SolveZeroSol(*this->x_block_[i], this->tmp_block_[i]);

        this->x_block_[i]->CopyFrom(*this->tmp_block_[i]);
    }

    // Insert Solution

    if(this->permutation_.GetSize() > 0)
    {
        // with permutation

        int x_offset = 0;
        for(int i = 0; i < this->num_blocks_; ++i)
        {
            this->x_.CopyFrom(*this->x_block_[i], 0, x_offset, this->block_sizes_[i]);

            x_offset += this->block_sizes_[i];
        }

        x->CopyFromPermuteBackward(this->x_, this->permutation_);
    }
    else
    {
        // without permutaion

        int x_offset = 0;
        for(int i = 0; i < this->num_blocks_; ++i)
        {
            x->CopyFrom(*this->x_block_[i], 0, x_offset, this->block_sizes_[i]);

            x_offset += this->block_sizes_[i];
        }
    }

    log_debug(this, "BlockPreconditioner::Solve()", " #*# end");
}

template <class OperatorType, class VectorType, typename ValueType>
void BlockPreconditioner<OperatorType, VectorType, ValueType>::MoveToHostLocalData_(void)
{
    log_debug(this, "BlockPreconditioner::MoveToHostLocalData_()", this->build_);

    if(this->build_ == true)
    {
        for(int i = 0; i < this->num_blocks_; ++i)
        {
            this->x_block_[i]->MoveToHost();
            this->tmp_block_[i]->MoveToHost();
            this->D_solver_[i]->MoveToHost();

            for(int j = 0; j < this->num_blocks_; ++j)
            {
                this->A_block_[i][j]->MoveToHost();
            }
        }

        this->permutation_.MoveToHost();
        this->x_.MoveToHost();
    }
}

template <class OperatorType, class VectorType, typename ValueType>
void BlockPreconditioner<OperatorType, VectorType, ValueType>::MoveToAcceleratorLocalData_(void)
{
    log_debug(this, "BlockPreconditioner::MoveToAcceleratorLocalData_()", this->build_);

    if(this->build_ == true)
    {
        for(int i = 0; i < this->num_blocks_; ++i)
        {
            this->x_block_[i]->MoveToAccelerator();
            this->tmp_block_[i]->MoveToAccelerator();
            this->D_solver_[i]->MoveToAccelerator();

            for(int j = 0; j < this->num_blocks_; ++j)
            {
                this->A_block_[i][j]->MoveToAccelerator();
            }
        }

        this->permutation_.MoveToAccelerator();
        this->x_.MoveToAccelerator();
    }
}

template class BlockPreconditioner<LocalMatrix<double>, LocalVector<double>, double>;
template class BlockPreconditioner<LocalMatrix<float>, LocalVector<float>, float>;
#ifdef SUPPORT_COMPLEX
template class BlockPreconditioner<LocalMatrix<std::complex<double>>,
                                   LocalVector<std::complex<double>>,
                                   std::complex<double>>;
template class BlockPreconditioner<LocalMatrix<std::complex<float>>,
                                   LocalVector<std::complex<float>>,
                                   std::complex<float>>;
#endif

} // namespace rocalution
